from interaction_instructions import *
from agent_personas import *
import random
import anthropic
from keys import *
import pickle
import os
from transformers import pipeline
import re
import argparse

# CONSTANTS
# Number of candidate questions per level ("k"); not referenced anywhere else in this file.
num_questions = 3 # k
# Shared, module-level transcript: fix_misunderstanding appends "Instructor: ..." /
# "Student: ..." lines and the prompt builders read it. It is never cleared, so it
# accumulates for the lifetime of the process.
convo_history = []

def parse_lines_yaml(yaml):
    """Extract the values of ``key_name: value`` lines from YAML-like model output.

    A key followed directly by a newline is first joined with the next line
    (``":\\n"`` becomes ``": "``). A line is kept only when it contains an
    underscore and a ``": "`` separator; everything after the first ``": "`` on
    that line is returned.

    Args:
        yaml: Raw text returned by a model.

    Returns:
        The extracted values, in order of appearance. Empty when nothing matches.
    """
    yaml = yaml.replace(":\n", ": ")
    values = []
    for line in yaml.split('\n'):
        # Explicit checks instead of a bare ``except`` around ``str.index``, which
        # would also have swallowed KeyboardInterrupt / SystemExit.
        if '_' not in line:
            continue
        _, separator, value = line.partition(': ')
        if separator:
            values.append(value)

    return values

def parse_yes_no(yaml, x):
    """Parse an ``explanation`` and a boolean field named ``x`` from model output.

    Args:
        yaml: Raw text returned by a model, expected to contain the lines
            ``explanation: ...`` and ``<x>: True`` / ``<x>: False``.
        x: Name of the boolean field to read.

    Returns:
        ``(explanation, answer)`` where ``explanation`` is the rest of the first
        ``explanation:`` line and ``answer`` is True only when the first ``x``
        field reads "true" (case-insensitive, ignoring surrounding whitespace
        and quotes).

    Raises:
        IndexError: If either field is missing from ``yaml``.
    """
    explanation = re.findall(r'explanation: (.+)', yaml, re.IGNORECASE)[0]
    # Escape the key so it is always matched literally, whatever characters it holds.
    raw_answer = re.findall(fr'{re.escape(x)}: (.+)', yaml, re.IGNORECASE)[0]
    # Models do not reliably emit the exact token 'True'; tolerate casing,
    # trailing whitespace and quoting rather than silently reading them as False.
    answer = raw_answer.strip().strip('\'"').lower() == 'true'

    print(f'EXPLANATION: {explanation}')
    print(f'ANSWER: {answer}')
    return explanation, answer

def log(text):
    """Append ``text`` to this run's ``log.txt``, framed by a separator line on each side.

    The log lives under ``instruct_not_assist/<FILE_NAME>/``; ``FILE_NAME`` is a
    module-level name that must be set before the first call.
    """
    separator = '\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n'
    log_path = f'instruct_not_assist/{FILE_NAME}/log.txt'
    with open(log_path, 'a+') as handle:
        for chunk in (separator, text, separator):
            handle.write(chunk)

class Verifier():
    """Client wrapper that sends verification prompts to an Anthropic model.

    Every request is a single user message. The module-level names
    ``verifier_persona``, ``problem_statement``, ``correct_code`` and
    ``buggy_code`` are read at construction or call time, so they must be
    defined before the verifier is used.
    """

    # Model and response budget shared by every verifier request.
    MODEL_NAME = "claude-3-sonnet-20240229"
    MAX_TOKENS = 1024

    def __init__(self):
        # Prefix prepended to every prompt sent through prompt_verifier.
        self.init_prompt = verifier_persona + problem_statement + correct_code
        self.model = anthropic.Anthropic(
            api_key=claude_api_key
        )

    def _complete(self, final_prompt, logged_prompt, label):
        """Send ``final_prompt`` as one user message, log the exchange, return the reply text.

        ``logged_prompt`` is what gets written to the log as the outgoing text
        and ``label`` is the progress line printed to stdout.
        """
        message = self.model.messages.create(
            model=self.MODEL_NAME,
            max_tokens=self.MAX_TOKENS,
            messages=[
                {"role": "user", "content": final_prompt}
            ]
        )
        reply = message.content[0].text

        print(label)
        log('--- TO VERIFIER: ' + logged_prompt)
        log(f'--- FROM VERIFIER: {reply}')

        return reply

    def _ask_yes_no(self, prompt_body, x):
        """Append the yes/no answer format for field ``x``, query the model, parse the reply."""
        response = self.prompt_verifier(prompt_body + yaml_yes_no(x))
        return parse_yes_no(response, x)

    def prompt_verifier(self, prompt):
        """Send ``prompt`` prefixed with ``self.init_prompt``; only ``prompt`` is logged."""
        return self._complete(self.init_prompt + prompt, prompt, 'prompted verifier')

    def prompt_verifier_state(self, prompt):
        """Send ``prompt`` together with the problem, the correct code and the buggy code.

        Unlike prompt_verifier, the prompt sits between the persona and the
        problem statement, and the current module-level ``buggy_code`` is appended.
        """
        final_prompt = verifier_persona + prompt + problem_statement + correct_code + buggy_code
        return self._complete(final_prompt, prompt, 'prompted verifier state')

    def get_state_repr(self):
        """Ask the model for the state representation and return its parsed lines."""
        prompt = v2v_get_state_repr + yaml_state_repr
        response = self.prompt_verifier_state(prompt)

        return parse_lines_yaml(response)

    def assess_misunderstanding(self, instructor_exp, student_exp):
        """Return ``(explanation, is_discrepancy)`` for the two explanations."""
        return self._ask_yes_no(i2i_discrepancy(instructor_exp, student_exp), "is_discrepancy")

    def assess_understanding_of_prev_level(self, instructor_question, discrepancy, current_response):
        """Return ``(explanation, does_student_understand)`` for the given response."""
        return self._ask_yes_no(
            i2i_prevq(instructor_question, current_response, discrepancy),
            "does_student_understand",
        )

    def assess_understanding_of_curr_level(self, instructor_question, current_response):
        """Return ``(explanation, did_student_answer_correctly)`` for the given response."""
        return self._ask_yes_no(
            i2i_correct_response(instructor_question, current_response),
            "did_student_answer_correctly",
        )

    def assess_state_level_understanding(self, instructor_question, current_response, target_representation):
        """Return ``(explanation, does_student_address_target)`` for one state attribute."""
        return self._ask_yes_no(
            i2i_address_target(instructor_question, current_response, target_representation),
            "does_student_address_target",
        )

    def update_code(self, student_bug_fixes):
        """Ask the model to apply ``student_bug_fixes`` to the module-level ``buggy_code``.

        This request does not include ``self.init_prompt``, and the full prompt
        is logged.

        Returns:
            Everything following the first ``corrected_code:`` marker in the reply.

        Raises:
            IndexError: If the reply contains no ``corrected_code:`` marker.
        """
        final_prompt = i2i_apply_bug_fixes(buggy_code, student_bug_fixes) + yaml_code_gen
        reply = self._complete(final_prompt, final_prompt, 'prompted verifier')

        # Join keys with values that the model put on the following line.
        yaml = reply.replace(":\n", ": ").replace(": \n", ": ")
        yaml = re.findall(r'corrected_code:([\S\s]*)', yaml, re.IGNORECASE)[0]
        return yaml
    
class Instructor():
    """Client wrapper that asks an Anthropic model to generate questions for the student.

    The module-level names ``instructor_persona``, ``bug_fixes``,
    ``problem_statement``, ``buggy_code`` and ``bug_description`` are read once
    at construction to build the prompt prefix.
    """

    def __init__(self):
        # Prefix prepended to every prompt sent through prompt_instructor.
        self.init_prompt = instructor_persona + bug_fixes + problem_statement + buggy_code + bug_description
        self.model = anthropic.Anthropic(
            api_key=claude_api_key
        )

    def prompt_instructor(self, prompt):
        """Send ``prompt`` (prefixed with ``self.init_prompt``) and return the reply text.

        Only ``prompt``, not the prefix, is written to the log.
        """
        final_prompt = self.init_prompt + prompt
        message = self.model.messages.create(
            model="claude-3-sonnet-20240229",
            max_tokens=1024,
            messages=[
                {"role": "user", "content": final_prompt}
            ]
        )
        print('prompted instructor')
        log('--- TO INSTRUCTOR: ' + prompt)
        log(f'--- FROM INSTRUCTOR: {message.content[0].text}')
        return message.content[0].text

    def generate_candidate_questions(self, level, prev_qs=None, discrepancy=None, explanation=""):
        """Generate candidate questions and return them as a list of strings.

        Args:
            level: Question level; level 0 uses the code-focused prompt.
            prev_qs: Previously asked questions, passed to the prompt builders
                for levels above 0.
            discrepancy: Passed through to whichever prompt builder is used.
            explanation: When non-empty (and ``level`` is not 0) the
                discrepancy-aware prompt is used and receives this text.

        Returns:
            The question lines parsed from the model's reply.
        """
        # conditional question generation
        if level == 0:
            prompt = i2i_cqg_code_one(discrepancy) + yaml_cqg
        elif explanation:
            # an explanation is available: re-ask at the same level
            prompt = i2i_qg_discrepancy_one(prev_qs, convo_history, bug_fixes, bug_description, discrepancy, explanation) + yaml_cqg
        else:
            # no explanation: questions for the next level
            prompt = i2i_qg_one(prev_qs, convo_history, bug_fixes, bug_description, discrepancy) + yaml_cqg

        return parse_lines_yaml(self.prompt_instructor(prompt))

    # def get_exp(self, line_no):
    #     prompt = instructor_gt_explanation(line_no) + yaml_code_explanation_one_line + correct_code
    #     response = self.prompt_instructor(prompt)
    #     return parse_lines_yaml(response)[0]

    @staticmethod
    def verify_code(code):
        """Placeholder check that currently accepts any ``code``.

        Declared static because it takes no ``self``; without the decorator,
        calling it on an instance raised TypeError.
        """
        return True # TODO
    
class Student():
    """Student driven by a local HuggingFace text-generation pipeline.

    The module-level names ``student_persona``, ``problem_statement`` and
    ``buggy_code`` are read once at construction to build the prompt prefix.
    """

    def __init__(self):
        self.model = pipeline("text-generation",
                              model="mistralai/Mistral-7B-Instruct-v0.2",
                              trust_remote_code=True,
                              device_map='auto')
        self.init_prompt = student_persona + problem_statement + buggy_code

    def prompt_student(self, prompt, suffix=""):
        """Wrap ``prompt`` in ``[INST]`` tags, generate, and return the full generated text.

        ``suffix`` is placed after ``[/INST]`` so the model continues from it.
        The returned text is the pipeline's ``generated_text``; callers split on
        ``[/INST]`` to isolate the part after the instruction.
        """
        final_prompt = "<s>[INST]" + self.init_prompt + prompt + "[/INST]" + suffix
        response = self.model(final_prompt, do_sample=True, top_k=10, num_return_sequences=1, max_new_tokens=200)[0]

        print('prompted student')
        log('--- TO STUDENT: ' + prompt)
        log('--- FROM STUDENT: ' + response['generated_text'].split('[/INST]')[-1])
        return response['generated_text']

    def parse_student_exp(self, yaml):
        """Return the text of the first ``line_explanation:`` line after ``[/INST]``.

        Raises:
            IndexError: If no ``line_explanation:`` line is present.
        """
        text = yaml.replace(":\n", ": ").replace(": \n", ": ")
        text = text.split('[/INST]')[-1]
        # ``.+`` already stops at the end of the line, so no trailing newline is
        # required; previously an explanation on the last line raised IndexError.
        explanation = re.findall(r'line_explanation: (.+)', text, re.IGNORECASE)[0]
        print(explanation)
        return explanation

    def parse_student_answer(self, yaml):
        """Return the first line following ``student_answer:`` after ``[/INST]``.

        The text is returned as found, including any leading space.

        Raises:
            IndexError: If no ``student_answer:`` marker is present.
        """
        text = yaml.replace(":\n", ": ").replace(": \n", ": ")
        text = text.split('[/INST]')[-1]
        answer = re.findall(r'student_answer:([\S\s]*)', text, re.IGNORECASE)[0]
        answer = answer.split("\n")[0]
        print(answer)
        return answer

    def generate_bug_fixes(self, state_representation):
        """Ask the student for bug fixes based on the conversation so far.

        Args:
            state_representation: ``[attributes, flags]``; only attributes whose
                flag is truthy are included in the prompt.

        Returns:
            The text of every ``bug_fix_<label>:`` entry found in the reply.
        """
        resolved = "".join(
            f"{attribute}\n"
            for attribute, is_resolved in zip(state_representation[0], state_representation[1])
            if is_resolved
        )
        prompt = i2s_generate_bug_fixes('\n'.join(convo_history), resolved)
        reply = self.prompt_student(prompt).split('[/INST]')[-1]
        # Accept any single-character label (as before) and also multi-digit
        # labels, so a tenth fix ("bug_fix_10") is no longer dropped.
        return re.findall(r'bug_fix_(?:\d+|.):\s*(.*)', reply, re.IGNORECASE)

    def ask_student(self, question):
        """Ask ``question`` and return the student's parsed one-line answer."""
        response = self.prompt_student(question + yaml_student_answer, "\nstudent_answer: ")
        return self.parse_student_answer(response)

    # def get_exp(self, line_no, suffix=""):
    #     prompt = ask_for_student_explanation(line_no) + yaml_code_explanation_one_line
    #     response = self.prompt_student(prompt, suffix)
    #     return self.parse_student_exp(response)
        

def fix_misunderstanding(student: Student, instructor: Instructor, verifier: Verifier, state_representation, target_rep):
    """Run the question/answer loop for one target state attribute.

    The instructor generates questions, the student answers them, and the
    verifier judges each answer. Once the loop ends the student is asked for bug
    fixes and the verifier applies them to produce new code.

    Args:
        student: Answers questions and proposes bug fixes.
        instructor: Generates candidate questions.
        verifier: Judges answers and applies the proposed bug fixes.
        state_representation: ``[attributes, flags]`` as two parallel lists;
            ``flags`` is updated in place as attributes are judged resolved.
        target_rep: The attribute to start from; must be an element of
            ``state_representation[0]``.

    Returns:
        ``(state_representation, new_code)``.

    Side effects:
        Appends to the module-level ``convo_history`` and to ``convo.txt`` and
        ``bug_fixes.txt`` under ``instruct_not_assist/<FILE_NAME>/``.
    """
    level = 0 # level 0 is asking about misunderstandings with code
    level_questions = {}     # level -> list of questions for that level
    level_indices = {}       # level -> index of the next question to ask at that level
    is_student_done = False
    level_explanations = {}  # level -> latest verifier explanation recorded at that level

    # error handling
    # NOTE(review): this -1 entry is never read anywhere in this function.
    level_questions[-1] = [f'Can you walk through the logic of your code?']

    # get explanations
    # student_exp = student.get_exp(line_no, "\nline_explanation: ")
    # instructor_exp = instructor.get_exp(line_no)

    # Text prepended to each question; switches between prefix_same_level and prefix_next_level.
    prefix = ""

    # discrepancy, _ = verifier.assess_misunderstanding(instructor_exp, student_exp)
    candidate_questions =instructor.generate_candidate_questions(level, discrepancy=target_rep)

    while not is_student_done:
        # First visit to a level: adopt the most recently generated candidates as its questions.
        if level not in level_questions.keys():
            level_questions[level] = candidate_questions
            level_indices[level] = 0 # setting i

        # get student answer for question i at level l
        # NOTE(review): raises IndexError if the instructor returned no questions for this level.
        instructor_question = prefix + level_questions[level][level_indices[level]]
        convo_history.append("Instructor: " + instructor_question)
        student_question_response = student.ask_student(instructor_question)
        convo_history.append("Student: " + student_question_response)
        level_indices[level] += 1
        
        # use verifier to see if student understands the curr level questions
        clu_explanation, is_curr_level_understand = verifier.assess_understanding_of_curr_level(instructor_question, student_question_response)
        if is_curr_level_understand:
            # TODO: ADD CONDITION THAT TAKES IN TARGET ATTRIBUTE AND SEE IF IT IS RESOLVED --> STOP
            # [[attribute_1, 2, ...], [False, False,..]]
            # Walk the attributes from the target onwards, stopping at the first unresolved one.
            idx = state_representation[0].index(target_rep)
            for i in range(idx, len(state_representation[0])):
                # is this state attribute actually resolved?
                exp, flag = verifier.assess_state_level_understanding(instructor_question, student_question_response, state_representation[0][i])
                # if it is resolved -> update state representation and move onto next attribute
                # if it is not resolved -> if no progression (i == idx), just prefix_next_level;
                #                       -> if progression (i > idx), ask student to generate bug fixes + ask instructor/verifier to update the code -> return state_repr + new code
                if flag:
                    state_representation[1][i] = True
                else:
                    level_explanations[level] = exp
                    if i == idx:
                        level += 1
                        prefix = prefix_next_level
                    else:
                        is_student_done = True
                    break
            # Every attribute up to and including the last one is resolved.
            if state_representation[1][-1]:
                is_student_done = True
        else:
            # no change in level
            prefix = prefix_same_level
            level_explanations[level] = clu_explanation

        with open(f'instruct_not_assist/{FILE_NAME}/convo.txt', 'a+') as f:
            # NOTE(review): the bare except hides any failure while writing the transcript.
            try:
                f.write(f'Instructor: {instructor_question}\n')
                f.write(f'Student: {student_question_response}\n')
                # f.write(f'Verifier: \n\t Previous Level: {is_prev_level_understand}, {plu_explanation}\n')
                f.write(f'\tCurrent Level: {is_curr_level_understand}, {clu_explanation}\n')
            except:
                f.write(f'\tCurrent Level: N/A, N/A\n')
        
        # generate new questions
        # need_append: the student is staying on this level and its question list is used up.
        need_append = prefix == prefix_same_level and (level_indices[level] >= len(level_questions[level]))
        if not is_student_done and (is_curr_level_understand or need_append):
            if prefix == prefix_same_level:
                # NOTE(review): reads the explanation stored for level-1 (clamped at 0), not the
                # one recorded for the current level - confirm this is intended.
                candidate_questions = instructor.generate_candidate_questions(level, prev_qs='\n'.join(level_questions[level]), explanation=level_explanations[max(0, level-1)])
            else:
                candidate_questions = instructor.generate_candidate_questions(level, prev_qs='\n'.join(level_questions[max(0, level - 1)]))

            if need_append:
                level_questions[level].extend(candidate_questions)

    # generate new code
    student_bug_fixes = student.generate_bug_fixes(state_representation)
    with open(f'instruct_not_assist/{FILE_NAME}/bug_fixes.txt', 'a+') as f:
        f.write(f'{student_bug_fixes}\n')
    new_code = verifier.update_code(student_bug_fixes)

    return state_representation, new_code     

def _format_state_repr(state_repr):
    """Render each state attribute and its resolved flag as one 'attribute, flag' line."""
    return "".join(f"{attribute}, {flag}\n" for attribute, flag in zip(state_repr[0], state_repr[1]))


def run():
    """Drive the tutoring session until every state attribute is marked resolved.

    The verifier produces the list of state attributes, then
    fix_misunderstanding is called repeatedly, each time targeting the first
    unresolved attribute. The code returned by each round replaces the
    module-level ``buggy_code``; the final value is written to
    ``instruct_not_assist/<FILE_NAME>/correct_code.txt``.
    """
    # ``buggy_code`` is read as a module global by Verifier.update_code and
    # Verifier.prompt_verifier_state. Without this declaration the assignment
    # below created a function-local name, so those methods kept using the
    # original code, and the final write raised UnboundLocalError whenever the
    # loop body never ran.
    global buggy_code

    instructor = Instructor()
    student = Student()
    verifier = Verifier()

    # get state representation based on student initial progress
    # starting point: buggy code; ending point: correct code
    state_attributes = verifier.get_state_repr()
    state_repr = [state_attributes, [False] * len(state_attributes)]
    log(f"State Representation: {_format_state_repr(state_repr)}")

    # TODO: RATHER THAN RESOLVING MISUNDERSTANDING PER BUGGY LINE OF CODE
    # RESOLVE BASED ON EACH ATTRIBUTE OF STATE REPRESENTATION
    # GENERATE NEW CODE AFTER EACH RESOLUTION AND

    while False in state_repr[1]:
        # Always work on the first attribute that is still unresolved.
        target = state_repr[0][state_repr[1].index(False)]
        state_repr, buggy_code = fix_misunderstanding(student, instructor, verifier, state_repr, target)
        log(f"State Representation: {_format_state_repr(state_repr)}")

    with open(f'instruct_not_assist/{FILE_NAME}/correct_code.txt', 'w+') as f:
        f.write(buggy_code)

# TODO get problem title from cmd and fill out
# problem_statement = """
# ---
# problem:
# Write a function `search(x: int, seq: List[int]) -> int` that returns the index of the first occurrence of `x` in `seq`. If `x` is not in `seq`, return the index where `x` should be inserted to keep `seq` sorted. Assume that `seq` is sorted in ascending order.
# ## Example Cases:
# ```
# search(5, [-1, 5, 8, 10, 12]) => 1
# search(-2, [-1, 57, 65]) => 0
# search(0, [-120, 60, 78, 100]) => 1
# search(77, [-100, -50, 5, 44, 66, 76, 99]) => 6
# search(55, [-99, -2, 0]) => 3
# ```
# ---
# """

# bug_description = """
# ---
# bug_desc:
# On line 3, the function only checks if `x` is less than `seq[i]` and then returns the index `i` where `x` should be inserted. When `x` is in `seq` at position `i`, the function returns the next index `i + 1` instead of the current index `i`.
# ---
# """

# buggy_code = """
# ---
# buggy_code:
# 1. def search(x, seq):
# 2.  for i in range(len(seq)):
# 3.    if x < seq[i]:
# 4.      return i
# 5.  return len(seq)
# ---
# """

# bug_fixes = """
# ---
# bug_fix:
# Replace `<` with `<=` on line 3
# ---
# """

# correct_code = """
# ---
# correct_code:
# 1. def search(x, seq):
# 2.  for i in range(len(seq)):
# 3.    if x <= seq[i]:
# 4.      return i
# 5.  return len(seq)
# ---
# """ 

# line_no = 3 # TODO: the fix doesn't have to be on the same line!! We give a line no for the explanation to both student and instructor
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', type=str, default='data_pkls/0_5_fibonacci_conversational_thread_1.pkl')
    args = parser.parse_args()

    # Name the output directory after the pickle's base name (text before the
    # first dot). Using basename also handles paths with no directory part or
    # with several nested directories.
    FILE_NAME = os.path.basename(args.file).split('.')[0]

    # Create the parent directory too when it is missing; real failures such as
    # permission errors now surface here instead of at the first log write.
    os.makedirs(f'instruct_not_assist/{FILE_NAME}', exist_ok=True)

    # NOTE: pickle.load can execute arbitrary code - only pass trusted files.
    with open(args.file, 'rb') as pkl_file:
        extracted_data = pickle.load(pkl_file)
    problem_statement = extracted_data['problem']
    buggy_code = extracted_data['bug_code']
    bug_fixes = extracted_data['bug_fixes']
    bug_description = extracted_data['bug_desc'] # not a typo
    correct_code = extracted_data['code']

    run()

